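// Package privilege implements user account authentication and privilege checks.
//
// Workflow:
//  1. Create a checker object, passing in the configuration of multiple instances.
//  2. Read user/DB information from multiple instances; privileges are granted by default.
//  3. Pass in the user name / host IP / DB to authenticate the account and check its privileges.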
package privilege

import (
	"bytes"
	"crypto/sha1"
	"database/sql"
	"encoding/hex"
	"fmt"
	"strconv"
	"strings"
	"sync"

	"github.com/pingcap/tidb/mysql"

	_ "github.com/go-sql-driver/mysql" //mysql driver
	"github.com/zeast/logs"
)

// Compile-time assertion that *UserChecker implements Checker.
var _ Checker = (*UserChecker)(nil)

const (
	// reloadTick is the reload period; per the original note 600 means 10 min,
	// i.e. the unit is seconds.
	// TODO: reloading would be better exposed as an API that is triggered only
	// when a user is added or a privilege changes.
	reloadTick = 600

	sysDB    = "mysql" // schema name of the privilege tables
	userTbl  = "user"  // table read by getUserPriv
	priveTbl = "db"    // table read by getDBPriv
)

// AuthChecker is the package-wide checker installed by InitAuthChecker after
// a successful initial load.
var AuthChecker *UserChecker

type (
	// userPrivilege holds one row of the mysql.user table: an account (user
	// plus host pattern), its stored password, its global privileges and its
	// max_user_connections limit.
	userPrivilege struct {
		user, host string
		password   []byte                       // hex-decoded when the column is 41 chars, raw otherwise
		priv       map[mysql.PrivilegeType]bool // privileges whose column value is "Y"
		maxConn    int
	}

	// dbPrivilege holds one row of the mysql.db table: the privileges a user
	// connecting from a host has on a database (db and host are compared
	// with match).
	dbPrivilege struct {
		db, user, host string
		priv           map[mysql.PrivilegeType]bool
	}

	// UserChecker authenticates users and checks their privileges against
	// in-memory snapshots of mysql.user and mysql.db. The embedded RWMutex
	// guards users and priv, which load replaces wholesale.
	UserChecker struct {
		sync.RWMutex
		authDSN string   // DSN used to read the privilege tables
		cb      Notifier // optional; consulted on every load when non-nil

		// The user name is looked up exactly; the host is matched per row.
		users map[string]map[string]userPrivilege // user => host => row
		priv  map[string]map[string]dbPrivilege   // user => "host:db" => row
	}
)

// InitAuthChecker builds a UserChecker that reads its privilege tables through
// dsn, performs the initial load, and publishes it as AuthChecker.
//
// AuthChecker is assigned only when the initial load succeeds; on failure it
// is left unchanged and the load error is returned.
func InitAuthChecker(dsn string) error {
	uc := &UserChecker{
		authDSN: dsn,
		users:   make(map[string]map[string]userPrivilege),
		priv:    make(map[string]map[string]dbPrivilege),
	}
	if err := uc.Reload(); err != nil {
		return err
	}
	AuthChecker = uc
	return nil
}

// Reload re-reads the user and database privilege tables and installs them.
// On failure the error is logged and returned, and the previously loaded
// tables remain in effect (load only swaps after both reads succeed).
func (uc *UserChecker) Reload() error {
	logs.Info("Reloading User/Priv ...")
	err := uc.load()
	if err == nil {
		logs.Info("Load User/Priv Finish")
		return nil
	}
	logs.Errorf("用户权限载入失败. %s", err)
	return err
}

// SetNofity installs cb as the Notifier consulted by subsequent loads; a nil
// cb disables notifications. The misspelled name is kept for API compatibility.
func (uc *UserChecker) SetNofity(cb Notifier) { uc.cb = cb }

// Auth authenticates the account described by user.
//
// user is split by parse into a name and a client host (the exact input
// format is defined by parse — TODO confirm). salt and auth are handed to
// validate; presumably salt is the server's handshake scramble and auth the
// client's scrambled response. Every loaded mysql.user row for name whose host
// matches is tried, and authentication succeeds if any of them validates.
//
// Returns (true, nil) on success; otherwise false together with parse's error
// or an ErrAccessDenied error.
func (uc *UserChecker) Auth(user string, salt []byte, auth []byte) (bool, error) {
	name, host, err := parse(user)
	if err != nil {
		// The message carries a %s verb, so it needs the formatting variant.
		logs.Errorf("用户认证格式错误. %s", err)
		return false, err
	}
	uc.RLock()
	defer uc.RUnlock()

	authes, ok := uc.users[name]
	if !ok {
		logs.Errorf("没有找到用户: %s", name)
		return false, mysql.NewErr(mysql.ErrAccessDenied, name, host, "YES")
	}
	for _, a := range authes {
		if match(host, a.host) && a.validate(salt, auth) {
			return true, nil
		}
	}
	return false, mysql.NewErr(mysql.ErrAccessDenied, name, host, "YES")
}

// HasPrivilege reports whether user may exercise privilege p on database db.
//
// user is split by parse into a name and client host. Global grants from
// mysql.user are checked first; any host-matching row granting p allows it.
// Otherwise the mysql.db rows for the user are consulted: every row whose host
// and db both match is examined, and p is allowed if any of them grants it.
// p == mysql.AllPriv is satisfied by any matching row.
//
// Errors:
//   - parse's error if user is malformed;
//   - ErrSpecificAccessDenied if the user has no db rows at all, or if some
//     rows match host/db but none grants p;
//   - ErrDBaccessDenied if no db row matches host and db.
//
// table is currently unused.
func (uc *UserChecker) HasPrivilege(user, db, table string, p mysql.PrivilegeType) (ok bool, err error) {
	name, host, err := parse(user)
	if err != nil {
		return false, err
	}
	uc.RLock()
	defer uc.RUnlock()

	// Global privileges apply regardless of database.
	for _, u := range uc.users[name] {
		if match(host, u.host) && (u.priv[p] || p == mysql.AllPriv) {
			return true, nil
		}
	}

	prives, found := uc.priv[name]
	if !found {
		return false, mysql.NewErr(mysql.ErrSpecificAccessDenied, mysql.Priv2Str[p])
	}

	// prives is a map, so its iteration order is random. Examine every
	// matching row rather than deciding on whichever happens to come first,
	// otherwise the outcome could change from call to call.
	matched := false
	for _, a := range prives {
		if !match(host, a.host) || !match(db, a.db) {
			continue
		}
		if a.priv[p] || p == mysql.AllPriv {
			return true, nil
		}
		matched = true
	}
	if matched {
		return false, mysql.NewErr(mysql.ErrSpecificAccessDenied, mysql.Priv2Str[p])
	}
	return false, mysql.NewErr(mysql.ErrDBaccessDenied, name, host, db)
}

// UserList returns, per user name, the sum of max_user_connections over all
// host rows loaded for that user. The returned map is freshly allocated.
func (uc *UserChecker) UserList() map[string]int {
	uc.RLock()
	defer uc.RUnlock()

	totals := make(map[string]int, len(uc.users))
	for name, rowsByHost := range uc.users {
		for _, row := range rowsByHost {
			totals[name] += row.maxConn
		}
	}
	return totals
}

// load reads mysql.user and mysql.db and calls the registered Notifier (if
// any) once per user name found. It then swaps the new tables in under the
// write lock and logs a summary.
//
// If either read fails, the error is returned and the currently installed
// tables are left untouched.
func (uc *UserChecker) load() error {
	users, err := uc.getUserPriv()
	if err != nil {
		return err
	}
	dbs, err := uc.getDBPriv()
	if err != nil {
		return err
	}

	// Both getters return freshly built maps, so they can be installed as-is.
	// Each user name is notified once, even if it appears in both tables.
	// NOTE(review): every known user is notified on every reload, not only
	// newly added ones — confirm Notifier.NotifyNew tolerates repeats.
	if uc.cb != nil {
		notified := make(map[string]bool, len(users)+len(dbs))
		for u := range users {
			notified[u] = true
			uc.cb.NotifyNew(u, 0)
		}
		for u := range dbs {
			if !notified[u] {
				notified[u] = true
				uc.cb.NotifyNew(u, 0)
			}
		}
	}

	uc.Lock()
	uc.users = users
	uc.priv = dbs
	uc.Unlock()

	// String reads uc.users/uc.priv without locking; hold the read lock so a
	// concurrent reload cannot swap the fields while the summary is rendered.
	uc.RLock()
	info := uc.String()
	uc.RUnlock()
	logs.Infof("%s", info)
	return nil
}

// String renders the loaded accounts (user and host list) and database grants
// (user and {host : db} pairs) as text for logging. Rows come from map
// iteration, so their order varies between calls. It does not take uc's lock;
// callers that may race with a reload should hold at least the read lock.
func (uc *UserChecker) String() string {
	// A buffer avoids the quadratic cost of repeated string concatenation.
	var buf bytes.Buffer
	buf.WriteString("User Privilege Info:\n\t\tUSER HOST\n")
	for name, byHost := range uc.users {
		fmt.Fprintf(&buf, "%20s ", name)
		for _, u := range byHost {
			buf.WriteString(u.host)
			buf.WriteString(", ")
		}
		buf.WriteString("\n")
	}
	buf.WriteString("\n\t\tUSER PRIVILIEGE{HOST:DATABASE}\n")
	for name, grants := range uc.priv {
		fmt.Fprintf(&buf, "%20s ", name)
		for _, g := range grants {
			fmt.Fprintf(&buf, "{%s : %s}, ", g.host, g.db)
		}
		buf.WriteString("\n")
	}
	return buf.String()
}

// validate checks a client's scrambled password response against the stored
// password.
//
// It computes SHA1(salt + stored) and XORs it with auth to recover a candidate.
// It accepts when SHA1(candidate) equals the stored value. This presumably
// implements MySQL native-password auth, with ua.password holding the 20-byte
// decoded hash.
//
// auth comes from the client, so a response shorter than the 20-byte digest
// (for example an empty one when no password is supplied) is rejected instead
// of indexing past its end. Bytes beyond the digest length are ignored.
func (ua *userPrivilege) validate(salt, auth []byte) bool {
	h := sha1.New()
	h.Write(salt)
	h.Write(ua.password)
	stage1 := h.Sum(nil)
	if len(auth) < len(stage1) {
		return false
	}
	for i := range stage1 {
		stage1[i] ^= auth[i]
	}
	h.Reset()
	h.Write(stage1)
	return bytes.Equal(h.Sum(nil), ua.password)
}

// getUserPriv loads account rows (non-empty User and Password) from the user
// table and indexes them as user => host => row.
//
// Besides User, Host, Password and max_user_connections, it selects every
// privilege column listed in mysql.Priv2UserCol; a privilege is recorded when
// its column value is "Y". The table is named via userTbl only, so the query
// presumably relies on the DSN selecting the mysql schema — TODO confirm.
func (uc *UserChecker) getUserPriv() (map[string]map[string]userPrivilege, error) {
	db, err := sql.Open("mysql", uc.authDSN)
	if err != nil {
		return nil, err
	}
	defer db.Close()

	cols := make([]string, 0, len(mysql.Priv2UserCol))
	for _, v := range mysql.Priv2UserCol {
		cols = append(cols, v)
	}
	colStr := strings.Join(cols, ",")
	authSQL := fmt.Sprintf("SELECT User,Host,Password,max_user_connections,%s FROM %s WHERE User > '' AND Password > '';", colStr, userTbl)
	rows, err := db.Query(authSQL)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, err
	}
	values := make([]sql.RawBytes, len(columns))
	scanArgs := make([]interface{}, len(columns))
	for i := range values {
		scanArgs[i] = &values[i]
	}

	authes := make(map[string]map[string]userPrivilege)
	for rows.Next() {
		if err := rows.Scan(scanArgs...); err != nil {
			logs.Debug(err)
			return nil, err
		}
		ua := userPrivilege{
			priv: make(map[mysql.PrivilegeType]bool),
		}

		for i, v := range values {
			switch col := columns[i]; col {
			case "User":
				ua.user = string(v)
			case "Host":
				ua.host = string(v)
			case "Password":
				if len(v) == 41 {
					// 41-char hashes are presumably "*" followed by 40 hex
					// digits; keep the decoded bytes that validate compares.
					decoded, err := hex.DecodeString(string(v[1:]))
					if err != nil {
						// A malformed hash can never validate; make that visible
						// rather than silently storing a truncated value.
						logs.Errorf("invalid password hash for user %s@%s: %s", ua.user, ua.host, err)
						decoded = nil
					}
					ua.password = decoded
				} else {
					// sql.RawBytes points into a buffer the driver reuses on the
					// next Next/Scan, so copy it before retaining it.
					ua.password = append([]byte(nil), v...)
				}
			case "max_user_connections":
				// Error deliberately ignored: syntax errors yield 0 and
				// out-of-range values are clamped to the int32 limits.
				n, _ := strconv.ParseInt(string(v), 10, 32)
				ua.maxConn = int(n)
			default:
				if p, ok := mysql.Col2PrivType[col]; ok && string(v) == "Y" {
					ua.priv[p] = true
				}
			}
		}
		if _, ok := authes[ua.user]; !ok {
			authes[ua.user] = make(map[string]userPrivilege)
		}
		// (user, host) is the row key.
		authes[ua.user][ua.host] = ua
	}
	// rows.Next also returns false when the server drops the connection
	// mid-result; without this check a truncated table would be accepted.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return authes, nil
}

// getDBPriv loads per-database grants (rows with a non-empty User) from the
// db table and indexes them as user => "host:db" => row.
//
// Only the Db, Host and User columns, plus columns known to mysql.Col2PrivType,
// are interpreted; a privilege is recorded when its column value is "Y". The
// table is named via priveTbl only, so the query presumably relies on the DSN
// selecting the mysql schema — TODO confirm.
func (uc *UserChecker) getDBPriv() (map[string]map[string]dbPrivilege, error) {
	db, err := sql.Open("mysql", uc.authDSN)
	if err != nil {
		return nil, err
	}
	defer db.Close()

	privSQL := fmt.Sprintf(`SELECT * FROM %s WHERE User > '';`, priveTbl)
	rows, err := db.Query(privSQL)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, err
	}
	values := make([]sql.RawBytes, len(columns))
	scanArgs := make([]interface{}, len(columns))
	for i := range values {
		scanArgs[i] = &values[i]
	}

	privs := make(map[string]map[string]dbPrivilege)
	for rows.Next() {
		if err := rows.Scan(scanArgs...); err != nil {
			return nil, err
		}
		priv := dbPrivilege{
			priv: make(map[mysql.PrivilegeType]bool),
		}
		for i, v := range values {
			// string(v) copies, so nothing retains the reused RawBytes buffer.
			switch col := columns[i]; col {
			case "Db":
				priv.db = string(v)
			case "Host":
				priv.host = string(v)
			case "User":
				priv.user = string(v)
			default:
				if p, ok := mysql.Col2PrivType[col]; ok && string(v) == "Y" {
					priv.priv[p] = true
				}
			}
		}
		if _, ok := privs[priv.user]; !ok {
			privs[priv.user] = make(map[string]dbPrivilege)
		}
		// host + db is unique per user, so it serves as the inner key.
		privs[priv.user][priv.host+":"+priv.db] = priv
	}
	// rows.Next also returns false when the server drops the connection
	// mid-result; without this check a truncated table would be accepted.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return privs, nil
}
